"""
Main file for training Yolo model on Pascal VOC dataset
"""
import os
import argparse
import torch
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm import tqdm
from torch.utils.data import DataLoader

from model.model import Yolov1
from model.loss import YoloLoss
from model.dataset import VOCDataset
from model.utils import (
    get_bboxes,
    mean_average_precision,
    save_checkpoint,
    load_checkpoint,
)

# Reproducibility: fix the RNG seed used for weight init and shuffling.
seed = 123
torch.manual_seed(seed)

# Hyperparameters etc.
WEIGHT_DECAY = 0  # Adam L2 regularization strength (disabled)
NUM_WORKERS = 2  # DataLoader worker processes
PIN_MEMORY = True  # page-locked host memory for faster host-to-GPU copies
LAST_MODEL_FILE = "last.pth.tar"  # checkpoint file used to resume training


class Compose(object):
    """Chain image transforms while passing bounding boxes through untouched.

    Unlike torchvision's own ``Compose``, each transform here receives only
    the image; the bboxes are returned exactly as given.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result, bboxes


# Resize every image to the 448x448 input expected by YOLOv1 and convert it
# to a tensor; bounding boxes pass through Compose unchanged.
transform = Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)


def train_fn(train_loader, model, optimizer, loss_fn, device):
    """Run one training epoch over `train_loader` and print the mean loss.

    Args:
        train_loader: DataLoader yielding (image, target) batches.
        model: network mapping a batch of images to predictions.
        optimizer: optimizer updating `model`'s parameters.
        loss_fn: callable computing a scalar loss from (output, target).
        device: torch device the batches are moved to.
    """
    loop = tqdm(train_loader, leave=True)
    mean_loss = []

    for batch_idx, (x, y) in enumerate(loop):
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = loss_fn(out, y)
        mean_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update progress bar
        loop.set_postfix(loss=loss.item())

    # Guard against an empty loader: the original unconditionally divided by
    # len(mean_loss) and raised ZeroDivisionError on zero batches.
    if mean_loss:
        print(f"Mean loss was {sum(mean_loss)/len(mean_loss)}")
    else:
        print("Mean loss was undefined (empty train_loader)")


def main(args):
    """Train YOLOv1 on Pascal VOC, reporting the train-set mAP each epoch.

    Resumes from LAST_MODEL_FILE unless --renew is given, and saves a
    checkpoint after every epoch so an interrupted run loses at most one
    epoch of work (previously the checkpoint was written only once, after
    the final epoch).
    """
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(args.device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YoloLoss()

    # Resume from the last checkpoint unless explicitly starting over.
    if (not args.renew) and os.path.exists(LAST_MODEL_FILE):
        load_checkpoint(torch.load(LAST_MODEL_FILE), model, optimizer)

    train_dataset = VOCDataset(
        "data/train.txt",
        transform=transform,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
        drop_last=True,
    )

    for epoch in range(args.epochs):
        train_fn(train_loader, model, optimizer, loss_fn, args.device)

        # NOTE(review): mAP is computed on the training set with the model
        # still in train mode — presumably intentional for this tutorial-style
        # script; confirm whether model.eval() is wanted here.
        pred_boxes, target_boxes = get_bboxes(
            train_loader, model, iou_threshold=0.5, threshold=0.4, device=args.device
        )

        mean_avg_prec = mean_average_precision(
            pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
        )
        # Fixed typo in the progress message ("Epoches" -> "Epoch").
        print(f"Epoch: {epoch+1}/{args.epochs}, Train mAP: {mean_avg_prec}")

        # Checkpoint each epoch so a crash loses at most one epoch.
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename=LAST_MODEL_FILE)


def parse_args():
    """Parse the command-line options for Yolo-v1 training."""
    parser = argparse.ArgumentParser(description="Yolo-v1 train")

    # (flag, kwargs) pairs keep the option table compact and uniform.
    option_specs = [
        ("--batch-size", dict(type=int, default=16, help="batch size")),
        ("--epochs", dict(type=int, default=10, help="training epochs")),
        ("--learning-rate", dict(type=float, default=1e-3, help="learning rate")),
        ("--renew", dict(action="store_true", help="training from beginning")),
        ("--device", dict(default="cuda:0", help="Device used for inference")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)
